syntax = "proto3";

package tensorflow.tpu;

import "google/protobuf/wrappers.proto";

// Optional lower and upper bounds used for clipping. Each bound is a
// google.protobuf.FloatValue wrapper message, so a bound that is not set can be
// told apart from a bound explicitly set to 0.0.
message ClippingLimits {
  // Lower bound; treated as -inf (no lower limit) if not set.
  google.protobuf.FloatValue lower = 1;
  // Upper bound; treated as +inf (no upper limit) if not set.
  google.protobuf.FloatValue upper = 2;
}

// Dynamic learning rate specification in the TPUEmbeddingConfiguration. The
// actual learning rates are provided as a scalar input list to the
// SendTPUEmbeddingGradients Op indexed by their tag specified through the
// following proto.
message DynamicLearningRate {
  // For tables where learning rates are dynamically computed and communicated
  // to the TPU embedding program, a tag must be specified for the learning
  // rate.
  //
  // The tag must be a non-negative integer. The total number of unique tags
  // must be less than or equal to the number of tables in the TPU embedding
  // configuration (a table does not specify any tag if it uses a constant
  // learning rate, and specifies exactly one tag if it uses dynamic learning
  // rates).
  //
  // All tags in the range [0, number_of_unique_tags) must be present in the TPU
  // embedding configuration, i.e. a tag cannot be skipped if a different tag
  // numerically greater than it is used in the configuration.
  //
  // If multiple tables specify the same tag, they *MUST* have
  // the same dynamic learning rate, for example, their dynamic learning rate
  // could be computed by the same TensorFlow sub-graph. The partitioning of the
  // embedding layer is more efficient if the number_of_unique_tags is as
  // *LOW* as possible, i.e., if many tables share the same tag.
  //
  // The learning_rate input of the SendTPUEmbeddingGradients op is used to
  // communicate dynamic learning rates to the TPU embedding program.
  // The learning_rate input is a list of scalars where the size of the list is
  // equal to the number of unique tags. The learning rate associated with a
  // particular tag is specified by populating its corresponding index in the
  // list of learning_rate scalars.
  int32 tag = 1;
}

// Source of learning rate to use. Exactly one of the two alternatives below
// can be set at a time because they share a oneof.
message LearningRate {
  oneof learning_rate {
    // A fixed learning rate value given directly in the configuration.
    float constant = 1;
    // A learning rate supplied at run time and identified by its tag; see
    // DynamicLearningRate for the rules that tags must follow.
    DynamicLearningRate dynamic = 2;
  }
}

// Each optimizer's parameter proto has a link to its documentation and CPU
// implementation (if available) for user reference.

// Hyper-parameters for the Adagrad optimization algorithm.
//
// https://www.tensorflow.org/api_docs/python/tf/train/AdagradOptimizer
// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L151
message AdagradParameters {
  // Initial value of the accumulator (presumably the counterpart of
  // initial_accumulator_value in the linked optimizer; verify).
  float initial_accumulator = 1;
}

// Selects the stochastic gradient descent optimization algorithm. The message
// is intentionally empty: it declares no hyper-parameters of its own.
//
// https://www.tensorflow.org/api_docs/python/tf/train/GradientDescentOptimizer
// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L423
message StochasticGradientDescentParameters {
}

// Hyper-parameters for the FTRL optimization algorithm.
//
// https://www.tensorflow.org/api_docs/python/tf/train/FtrlOptimizer
// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L192
message FtrlParameters {
  // Presumably the L1 regularization strength; verify against the linked
  // optimizer documentation.
  float l1 = 1;
  // Presumably the L2 regularization strength; verify against the linked
  // optimizer documentation.
  float l2 = 2;
  // Presumably the learning rate power of the linked optimizer; verify.
  float lr_power = 3;
  // Initial value of the accumulator state variable (assumed from the name).
  float initial_accum = 4;
  // Initial value of the linear state variable (assumed from the name).
  float initial_linear = 5;
}

// The Adam optimizer does not implement hyper-parameter update; use the dynamic
// learning rate feature instead, setting the learning rate to:
// user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
// Here, t is the current timestep.
//
// https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
// https://github.com/tensorflow/tensorflow/blob/ab51450c817674c8ff08a7ae4f8ac50cdc4bed8b/tensorflow/python/training/adam.py#L54
//
// Note that the code by default implements the lazy version of Adam
// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/LazyAdamOptimizer)
// unless the use_non_lazy_adam parameter is set, in which case it implements
// the normal version of Adam that updates all parameters in the embedding
// table, even for entries that are not used in the current minibatch
// (https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer). If
// use_non_lazy_adam is enabled, gradient accumulation is also required to be
// enabled in order to get correct results; a warning will be printed otherwise
// (which may change to an error in the future). If use_sum_inside_sqrt is set,
// the Adam variable update formula will be changed from m / (sqrt(v) + epsilon)
// to m / sqrt(v + epsilon**2); this option improves the performance of TPU
// training and is not expected to harm model quality.
//
// NOTE(review): field numbers 1, 2 and 9 are not used by this message. If they
// belonged to fields that were removed, consider marking them `reserved` so
// that they cannot be reused by accident.
message AdamParameters {
  // The beta1 term of the learning rate formula above.
  float beta1 = 3;
  // The beta2 term of the learning rate formula above.
  float beta2 = 4;
  // The epsilon term of the variable update formula above.
  float epsilon = 5;
  // Initial value of m (assumed from the name; m is the numerator of the
  // variable update formula above).
  float initial_m = 6;
  // Initial value of v (assumed from the name; v is the term under the square
  // root in the variable update formula above).
  float initial_v = 7;
  // If set, use the normal (non-lazy) version of Adam that updates all
  // parameters in the embedding table; gradient accumulation must then also be
  // enabled to get correct results.
  bool use_non_lazy_adam = 8;
  // If set, the update uses m / sqrt(v + epsilon**2) instead of
  // m / (sqrt(v) + epsilon).
  bool use_sum_inside_sqrt = 10;
}

// Hyper-parameters for the Momentum optimization algorithm.
//
// https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer
// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L271
message MomentumParameters {
  // The momentum coefficient (see the linked optimizer documentation).
  float momentum = 1;
  // Presumably selects the Nesterov variant of momentum; verify against the
  // linked optimizer documentation.
  bool use_nesterov = 2;
  // Initial value of the accumulator state variable (assumed from the name).
  float initial_accum = 3;
}

// Hyper-parameters for the RMSProp optimization algorithm.
//
// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L356
message RmsPropParameters {
  // Presumably the decay rate of the linked optimizer; verify.
  float rho = 1;
  // The momentum coefficient (see the linked optimizer documentation).
  float momentum = 2;
  // Presumably a small constant added for numerical stability; verify.
  float epsilon = 3;
  // Initial value of the "ms" state variable (assumed from the name).
  float initial_ms = 4;
  // Initial value of the "mom" state variable (assumed from the name).
  float initial_mom = 5;
}

// Hyper-parameters for the centered variant of the RMSProp optimization
// algorithm. It has the same fields as RmsPropParameters plus initial_mg.
//
// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L372
message CenteredRmsPropParameters {
  // Presumably the decay rate of the linked optimizer; verify.
  float rho = 1;
  // The momentum coefficient (see the linked optimizer documentation).
  float momentum = 2;
  // Presumably a small constant added for numerical stability; verify.
  float epsilon = 3;
  // Initial value of the "ms" state variable (assumed from the name).
  float initial_ms = 4;
  // Initial value of the "mom" state variable (assumed from the name).
  float initial_mom = 5;
  // Initial value of the "mg" state variable (assumed from the name).
  float initial_mg = 6;
}

// Hyper-parameters for the MDL Adagrad Light optimization algorithm.
//
// Variant of algorithm in http://proceedings.mlr.press/v44/shamir15.pdf
//
// NOTE(review): the meaning of the individual fields is not documented in this
// file; consult the paper above and the TPU implementation before relying on
// any of them. The initial_* fields presumably give the starting values of the
// corresponding optimizer state variables (to be confirmed).
message MdlAdagradLightParameters {
  float l2 = 1;
  float lr_power = 2;
  float min_servable_mdl_benefit = 3;
  float mdl_mix_in_margin = 4;
  float mdl_benefit_rampup_coeff = 5;
  float mdl_min_weight = 6;
  float benefit_revisit_scale = 7;
  float max_event_benefit = 8;
  float max_total_benefit = 9;
  float mdl_hard_limit = 10;
  bool hard_limit_min_benefit = 11;
  bool mdl_regularize = 12;
  float initial_accumulator = 13;
  float initial_weight = 14;
  float initial_benefit = 15;
}

// Hyper-parameters for the Adadelta optimization algorithm.
//
// https://www.tensorflow.org/api_docs/python/tf/train/AdadeltaOptimizer
// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L68
message AdadeltaParameters {
  // Presumably the decay rate of the linked optimizer; verify.
  float rho = 1;
  // Presumably a small constant added for numerical stability; verify.
  float epsilon = 2;
  // Initial value of the accumulator state variable (assumed from the name).
  float initial_accumulator = 3;
  // Initial value of the update state variable (assumed from the name).
  float initial_update = 4;
}

// Hyper-parameters for the Proximal Adagrad optimization algorithm.
//
// https://www.tensorflow.org/api_docs/python/tf/train/ProximalAdagradOptimizer
// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L164
message ProximalAdagradParameters {
  // Presumably the L1 regularization strength; verify against the linked
  // optimizer documentation.
  float l1 = 1;
  // Presumably the L2 regularization strength; verify against the linked
  // optimizer documentation.
  float l2 = 2;
  // Initial value of the accumulator state variable (assumed from the name).
  float initial_accumulator = 3;
}

// Status of using gradient accumulation (doing two passes over the input
// gradients: one to accumulate them into a temporary array and another to apply
// them using the actual optimization algorithm). The extra message is to wrap
// the enum for scoping.
message GradientAccumulationStatus {
  // If UNSPECIFIED (default), gradient accumulation is ENABLED.
  enum Status {
    // Default value when the field is not set; treated the same as ENABLED.
    UNSPECIFIED = 0;
    // Gradient accumulation is explicitly turned on.
    ENABLED = 1;
    // Gradient accumulation is explicitly turned off.
    DISABLED = 2;
  }
}

// Configuration proto for hot ID optimization. This is an experimental feature
// that is currently disabled (by default).
message HotIdOptimizerConfiguration {
  // Whether to enable or disable hot ID optimization.
  // If UNSPECIFIED (default), hot ID optimization is DISABLED.
  enum Status {
    // Default value when the field is not set; treated the same as DISABLED.
    UNSPECIFIED = 0;
    // Hot ID optimization is explicitly turned on.
    ENABLED = 1;
    // Hot ID optimization is explicitly turned off.
    DISABLED = 2;
  }
  // Selects whether hot ID optimization is used; see Status above.
  Status status = 1;

  // The following fields are never expected to be set by the TF model. However,
  // a TF model could set them if it chooses to. If the fields are not set,
  // meaningful default values will be chosen by the TPU software.

  // Frequency above which an embedding ID is classified as hot. The valid
  // range for the frequency is [0.0, 1.0]. The frequency of an embedding ID is
  // defined as the ratio of the number of lookups for that ID to the total
  // number of lookups for the embedding table.
  float frequency_threshold = 2;

  // The maximum number of hot IDs for the embedding table. If greater than
  // max_id_count hot IDs exist for the table, the IDs with the highest
  // frequencies are chosen.
  int32 max_id_count = 3;

  // The maximum number of slots reserved in HBM (across the entire TPU system)
  // for storing the replicas of hot IDs for the embedding table. In the future,
  // the number of replicas for a particular hot ID could be adjusted based on
  // its frequency. The max_slot_count value captures the total number of
  // replicas across all hot IDs for the table.
  int32 max_slot_count = 4;
}

// Settings that control how the embedding layer parameters are updated: the
// learning rate, clipping limits, weight decay, gradient accumulation, hot ID
// optimization, and the optimization algorithm together with its
// hyper-parameters.
message OptimizationParameters {
  // Learning rate used for updating the embedding layer parameters.
  LearningRate learning_rate = 13;
  reserved 1;  // Old learning rate tag.

  // Limits to which to clip the weight values after the backward pass; not
  // present means no limits are applied.
  ClippingLimits clipping_limits = 2;

  // Limits to which to clip the backward pass gradient before using it for
  // updates; not present means no limits are applied.
  ClippingLimits gradient_clipping_limits = 7;

  // Amount of weight decay to apply; see weight_decay_optimizers.py for
  // details. Almost all optimizers are supported with this option (MDL Adagrad
  // Light does not work, and SGD does not behave as expected if it is enabled).
  // Although there is no check, users who want weight decay will probably also
  // want to enable gradient accumulation so that the decay will happen once per
  // minibatch.
  float weight_decay_factor = 16;

  // Status of using gradient accumulation (doing two passes over the input
  // gradients: one to accumulate them into a temporary array and another to
  // apply them using the actual optimization algorithm). If left UNSPECIFIED
  // (the default), gradient accumulation is ENABLED.
  GradientAccumulationStatus.Status gradient_accumulation_status = 17;

  // Configuration proto for hot ID optimization. This is an experimental
  // feature that is currently disabled (by default).
  HotIdOptimizerConfiguration hot_id_optimizer_configuration = 18;

  // Optimization algorithm parameters; which field is selected determines which
  // algorithm to use. At most one of these fields can be set at a time because
  // they share a oneof.
  oneof parameters {
    AdagradParameters adagrad = 3;
    StochasticGradientDescentParameters stochastic_gradient_descent = 4;
    FtrlParameters ftrl = 5;
    AdamParameters adam = 6;
    MomentumParameters momentum = 8;
    RmsPropParameters rms_prop = 9;
    CenteredRmsPropParameters centered_rms_prop = 10;
    MdlAdagradLightParameters mdl_adagrad_light = 11;
    AdadeltaParameters adadelta = 12;
    ProximalAdagradParameters proximal_adagrad = 14;
  }

  reserved 15;  // Old use_gradient_accumulation.
}

// Specification of an optimization algorithm's state variables (both the main
// value vector and any extra accumulators, etc.). This proto is only used
// internally by the TPU software and is not exposed directly to the TF model.
message StateVariableSpecification {
  // Parameter name for the state variable.
  string name = 1;

  // A normal state variable that should be saved and restored in checkpoints
  // and used as an input or output to non-debug TensorFlow ops.
  message UserDefined {
    // For padding embedding rows, this field specifies the initial value to be
    // used. Separate initial values need to be specified for the embeddings and
    // any extra accumulators. The initial values should be specified so as to
    // maintain two invariants during model training:
    // (1) The embedding vector multiplied by zero returns a vector containing
    //     all zeros. To maintain this invariant, the embedding values should
    //     never be NaNs or +-infinity.
    // (2) Repeatedly applying the optimizer using a gradient vector of all
    //     zeros does not cause the embeddings or slot variables to become NaNs
    //     or +-infinity.
    // The padding row is looked up when no embedding IDs are present for a
    // feature. The semantics of embedding lookup dictate that the output must
    // be zero under this scenario.
    double padding_initial_value = 1;
  }

  // A state variable that should be filled with a constant and normally hidden
  // from users (used for intermediate gradients being accumulated, for
  // example).
  message FillWithConstant {
    // The constant with which the state variable is filled.
    double initial_value = 1;
  }

  // Usage type of this state variable. At most one of the alternatives can be
  // set at a time because they share a oneof.
  oneof usage {
    // The variable is a normal, checkpointed state variable.
    UserDefined user_defined = 2;
    // The variable is filled with a constant and normally hidden from users.
    FillWithConstant fill_with_constant = 3;
  }
}
